{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "Conversion of Objax models to Tensorflow",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bxqOyLRaUkYW"
      },
      "source": [
        "# Conversion of Objax models to Tensorflow\n",
        "\n",
        "This tutorial demonstrates how to export models from Objax to Tensorflow and then export them into SavedModel format.\n",
        "\n",
        "SavedModel format could be read and served by [Tensorflow serving infrastructure](https://www.tensorflow.org/tfx/guide/serving) or by custom user code written in C++. Thus export to Tensorflow allows users to potentially run experiments in Objax and then serve these models in production (using Tensorflow infrastructure)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zYU-k0VsVRbL"
      },
      "source": [
        "## Installation and Imports\n",
        "\n",
        "First of all, let's install Objax and import all necessary python modules."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2dg5WFQApM1_",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "a1cc8b58-615a-499e-c6d3-97277ae54da9"
      },
      "source": [
        "# install the latest version of Objax from github\n",
        "%pip --quiet install git+https://github.com/google/objax.git"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  Building wheel for objax (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oizwGGdIUMr0"
      },
      "source": [
        "import math\n",
        "import random\n",
        "import tempfile\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "import objax\n",
        "from objax.zoo.wide_resnet import WideResNet"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FJ65n3dQVvll"
      },
      "source": [
        "## Setup Objax model\n",
        "\n",
        "Let's make a model in Objax and create a prediction operation which we will be later converting to Tensorflow.\n",
        "\n",
        "In this tutorial we use randomly initialized model, so we don't need to wait for model training to finish. However conversion to Tensorflow would be the same if we train model first. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mBQ-xNSDWIsj"
      },
      "source": [
        "# Model\n",
        "model = WideResNet(nin=3, nclass=10, depth=4, width=1)\n",
        "\n",
        "# Prediction operation\n",
        "@objax.Function.with_vars(model.vars())\n",
        "def predict_op(x):\n",
        "  return objax.functional.softmax(model(x, training=False))\n",
        "\n",
        "predict_op = objax.Jit(predict_op)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pyfcbbJqZWFt"
      },
      "source": [
        "Now, let's generate a few examples and run prediction operation on them:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tgvMAfNjYw4A",
        "outputId": "7f5197b0-a537-4b5b-f7d8-2f18651d4957"
      },
      "source": [
        "input_shape = (4, 3, 32, 32)\n",
        "\n",
        "x1 = np.random.uniform(size=input_shape)\n",
        "y1 = predict_op(x1)\n",
        "print('y1:\\n', y1)\n",
        "\n",
        "x2 = np.random.uniform(size=input_shape)\n",
        "y2 = predict_op(x2)\n",
        "print('y2:\\n', y2)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "y1:\n",
            " [[0.06012213 0.07564961 0.13191961 0.12591475 0.08727679 0.16077745\n",
            "  0.10821478 0.07958559 0.08044666 0.09009264]\n",
            " [0.05957783 0.07606611 0.13554339 0.12834728 0.08717373 0.16128528\n",
            "  0.10691827 0.07874896 0.07711729 0.08922189]\n",
            " [0.06061443 0.07384451 0.13470736 0.12721166 0.0860136  0.16349576\n",
            "  0.11025441 0.0766369  0.07674664 0.09047476]\n",
            " [0.05972087 0.07695024 0.1373597  0.12381804 0.08652159 0.16483182\n",
            "  0.10871573 0.07637089 0.07585222 0.08985896]]\n",
            "y2:\n",
            " [[0.05968136 0.08011787 0.13363907 0.12838946 0.08562963 0.15783262\n",
            "  0.10989606 0.07535356 0.07875139 0.09070906]\n",
            " [0.0572416  0.07606035 0.13607582 0.12197609 0.08373585 0.16551377\n",
            "  0.11429026 0.07743792 0.0776429  0.0900254 ]\n",
            " [0.06138196 0.07201274 0.13394636 0.12132262 0.08225243 0.1682174\n",
            "  0.11442989 0.07763992 0.0776984  0.09109832]\n",
            " [0.05850162 0.07468063 0.13054986 0.12376051 0.08367112 0.16295902\n",
            "  0.11684892 0.07688387 0.08140761 0.09073676]]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_QdMusvnZjfb"
      },
      "source": [
        "## Convert a model to Tensorflow\n",
        "\n",
        "We use `Objax2Tf` object to convert Objax module into `tf.Module`.\n",
        "\n",
        "Internally `Objax2Tf` makes a copy of all Objax variables used by the provided module and converts `__call__` method of the provided Objax module\n",
        "into [Tensorflow function](https://www.tensorflow.org/api_docs/python/tf/function)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "I3EHtu_8Zrsj",
        "outputId": "6d8d8350-8bb1-4b57-e4d9-75602ea3eb82"
      },
      "source": [
        "predict_op_tf = objax.util.Objax2Tf(predict_op)\n",
        "\n",
        "print('isinstance(predict_op_tf, tf.Module) =', isinstance(predict_op_tf, tf.Module))\n",
        "print('Number of variables: ', len(predict_op_tf.variables))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "isinstance(predict_op_tf, tf.Module) = True\n",
            "Number of variables:  39\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d-J40ksuf0IC"
      },
      "source": [
        "After module is converted we can run it and compare results between Objax and Tensorflow. Results are pretty close numerically, however they are not exactly the same due to implementation differences between JAX and Tensorflow."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jNlXguBwBfZj",
        "outputId": "2879556e-c0a5-4422-df93-d32535b2486a"
      },
      "source": [
        "y1_tf = predict_op_tf(x1)\n",
        "print('max(abs(y1_tf - y1)) =', np.amax(np.abs(y1_tf - y1)))\n",
        "\n",
        "y2_tf = predict_op_tf(x2)\n",
        "print('max(abs(y2_tf - y2)) =', np.amax(np.abs(y2_tf - y2)))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "max(abs(y1_tf - y1)) = 4.4703484e-08\n",
            "max(abs(y2_tf - y2)) = 2.2351742e-08\n"
          ],
          "name": "stdout"
        }
      ]
    },
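    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Printing the maximum absolute difference is handy for inspection, but an automated check usually needs an explicit tolerance. The cell below is a minimal sketch of such a check built on `np.testing.assert_allclose`. The helper name `assert_outputs_close`, the default tolerance of `1e-6` and the synthetic arrays are assumptions made for illustration and are not part of Objax. In this notebook the helper could be called as `assert_outputs_close(y1_tf, y1)`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def assert_outputs_close(actual, expected, atol=1e-6):\n",
        "  # Hypothetical helper: the name and the default tolerance are assumptions.\n",
        "  actual = np.asarray(actual, dtype=np.float64)\n",
        "  expected = np.asarray(expected, dtype=np.float64)\n",
        "  # rtol=0 makes this a purely absolute comparison;\n",
        "  # an AssertionError is raised if any element differs by more than atol.\n",
        "  np.testing.assert_allclose(actual, expected, rtol=0, atol=atol)\n",
        "  # Return the largest elementwise deviation for reporting.\n",
        "  return float(np.max(np.abs(actual - expected)))\n",
        "\n",
        "\n",
        "# Synthetic stand-ins for the Objax and TensorFlow outputs.\n",
        "reference = np.full((4, 10), 0.1, dtype=np.float32)\n",
        "perturbed = reference + np.float32(4e-8)\n",
        "max_diff = assert_outputs_close(perturbed, reference)\n",
        "print('max abs difference:', max_diff)"
      ],
      "execution_count": null,
      "outputs": []
    },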
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WoytOXKMDjCH"
      },
      "source": [
        "## Export Tensorflow model as SavedModel\n",
        "\n",
        "Converting an Objax model to Tensorflow allows us to export it as [Tensorflow SavedModel](https://www.tensorflow.org/guide/saved_model).\n",
        "\n",
        "Discussion of details of SavedModel format is out of scope of this tutorial, thus we only provide an example showing how to save and load SavedModel. For more details about SavedModel please refert to the following Tensorflow documentation:\n",
        "\n",
        "* [Using the SavedModel format](https://www.tensorflow.org/guide/saved_model) guide\n",
        "* [tf.saved_model.save](https://www.tensorflow.org/api_docs/python/tf/saved_model/save) API call\n",
        "* [tf.saved_model.load](https://www.tensorflow.org/api_docs/python/tf/saved_model/load) API call"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hUWlR7212BzS"
      },
      "source": [
        "### Saving model as SavedModel\n",
        "\n",
        "First of all, let's create a new empty directory where model will be saved:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6HwkoRyLwi02",
        "outputId": "c6c36da5-0c1f-43b2-aef7-e1f893447efd"
      },
      "source": [
        "model_dir = tempfile.mkdtemp()\n",
        "\n",
        "%ls -al $model_dir"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "total 8\n",
            "drwx------ 2 root root 4096 Dec 17 23:28 \u001b[0m\u001b[01;34m.\u001b[0m/\n",
            "drwxrwxrwt 1 root root 4096 Dec 17 23:28 \u001b[30;42m..\u001b[0m/\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mJ28E6xbw1kJ"
      },
      "source": [
        "Then let's use `tf.saved_model.save` API to save our Tensorflow model.\n",
        "Since `Objax2Tf` is a subclass of `tf.Module`, instances of `Objax2Tf` class could be directly used with `tf.saved_model.save` API:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "S8a8I_ZAC83b",
        "outputId": "e4605384-e044-43f3-aaf8-bf59f18ed8e4"
      },
      "source": [
        "tf.saved_model.save(\n",
        "    predict_op_tf,\n",
        "    model_dir,\n",
        "    signatures=predict_op_tf.__call__.get_concrete_function(\n",
        "        tf.TensorSpec(input_shape, tf.float32)))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Assets written to: /tmp/tmpppqtia2e/assets\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WEzWxSo1wB5M"
      },
      "source": [
        "Now we can list the content of `model_dir` and see files and subdirectories of SavedModel:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yaONSDb3FYsR",
        "outputId": "e6e0ab5b-c400-47f6-8ee2-f31b48f04eed"
      },
      "source": [
        "%ls -al $model_dir"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "total 612\n",
            "drwx------ 4 root root   4096 Dec 17 23:28 \u001b[0m\u001b[01;34m.\u001b[0m/\n",
            "drwxrwxrwt 1 root root   4096 Dec 17 23:28 \u001b[30;42m..\u001b[0m/\n",
            "drwxr-xr-x 2 root root   4096 Dec 17 23:28 \u001b[01;34massets\u001b[0m/\n",
            "-rw-r--r-- 1 root root 608158 Dec 17 23:28 saved_model.pb\n",
            "drwxr-xr-x 2 root root   4096 Dec 17 23:28 \u001b[01;34mvariables\u001b[0m/\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VUaZs6E4xD6B"
      },
      "source": [
        "### Loading exported SavedModel\n",
        "\n",
        "We can load SavedModel as a new Tensorflow object `loaded_tf_model`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TPOF9DhBFDZA",
        "outputId": "c9f404db-7376-43b6-b769-31cf1a593330"
      },
      "source": [
        "loaded_tf_model = tf.saved_model.load(model_dir)\n",
        "print('Exported signatures: ', loaded_tf_model.signatures)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:absl:Importing a function (__inference___call___2834) with ops with custom gradients. Will likely fail if a gradient is requested.\n",
            "WARNING:absl:Importing a function (__inference___call___5626) with ops with custom gradients. Will likely fail if a gradient is requested.\n",
            "WARNING:absl:Importing a function (__inference___call___4272) with ops with custom gradients. Will likely fail if a gradient is requested.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Exported signatures:  _SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(*, args_0) at 0x7FF247EE5F28>})\n"
          ],
          "name": "stdout"
        }
      ]
    },
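    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before calling a loaded signature, it can help to check which inputs and outputs it expects. The cell below is a sketch on a tiny stand-in module, so that it runs on its own; the class `Doubler`, its input shape and the `demo_*` variable names are hypothetical and are not part of Objax. The same attributes, `structured_input_signature` and `structured_outputs`, should be available on `loaded_tf_model.signatures['serving_default']`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import tempfile\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "class Doubler(tf.Module):\n",
        "  # Hypothetical stand-in for a converted model.\n",
        "\n",
        "  @tf.function\n",
        "  def __call__(self, x):\n",
        "    return 2.0 * x\n",
        "\n",
        "\n",
        "demo_module = Doubler()\n",
        "demo_dir = tempfile.mkdtemp()\n",
        "tf.saved_model.save(\n",
        "    demo_module,\n",
        "    demo_dir,\n",
        "    signatures=demo_module.__call__.get_concrete_function(\n",
        "        tf.TensorSpec((4, 3), tf.float32)))\n",
        "\n",
        "# Keep a reference to the loaded object while its signatures are in use.\n",
        "demo_loaded = tf.saved_model.load(demo_dir)\n",
        "demo_signature = demo_loaded.signatures['serving_default']\n",
        "\n",
        "# The input signature is an (args, kwargs) pair; signature inputs are keyword arguments.\n",
        "_, demo_input_specs = demo_signature.structured_input_signature\n",
        "print('inputs: ', demo_input_specs)\n",
        "# Outputs are returned as a dict keyed by output name.\n",
        "print('outputs:', demo_signature.structured_outputs)\n",
        "\n",
        "# Call the signature by keyword, using whatever input name was exported.\n",
        "demo_input_name = next(iter(demo_input_specs))\n",
        "demo_result = demo_signature(**{demo_input_name: tf.ones((4, 3), tf.float32)})\n",
        "demo_output_name = next(iter(demo_result))\n",
        "print(demo_output_name, demo_result[demo_output_name].numpy()[0])"
      ],
      "execution_count": null,
      "outputs": []
    },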
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R_BvaRdAzVDi"
      },
      "source": [
        "Then we can run inference using loaded Tensorflow model `loaded_tf_model` and compare resuls with the model `predict_op_tf` which was converted from Objax:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Q95C5AUuF2hK",
        "outputId": "4314ff8f-8eb7-42b3-ddab-e254284eb774"
      },
      "source": [
        "loaded_predict_op_tf = loaded_tf_model.signatures['serving_default']\n",
        "\n",
        "y1_loaded_tf = loaded_predict_op_tf(tf.cast(x1, tf.float32))['output_0']\n",
        "print('max(abs(y1_loaded_tf - y1_tf)) =', np.amax(np.abs(y1_loaded_tf - y1_tf)))\n",
        "\n",
        "y2_loaded_tf = loaded_predict_op_tf(tf.cast(x2, tf.float32))['output_0']\n",
        "print('max(abs(y2_loaded_tf - y2_tf)) =', np.amax(np.abs(y2_loaded_tf - y2_tf)))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "max(abs(y1_loaded_tf - y1_tf)) = 0.0\n",
            "max(abs(y2_loaded_tf - y2_tf)) = 0.0\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}